{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ItXfxkxvosLH"
   },
   "source": [
    "# TensorFlow and TextAttack"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/QData/TextAttack/blob/master/docs/2notebook/Example_0_tensorflow.ipynb)\n",
    "\n",
    "[![View Source on GitHub](https://img.shields.io/badge/github-view%20source-black.svg)](https://github.com/QData/TextAttack/blob/master/docs/2notebook/Example_0_tensorflow.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WooZ9pGnNJbv"
   },
   "source": [
    "## Training\n",
    "\n",
    "\n",
    "\n",
    "The following code trains a text classification model with TensorFlow, using its Keras API. It comes from the TensorFlow documentation ([see here](https://www.tensorflow.org/tutorials/keras/text_classification_with_hub)).\n",
    "\n",
    "This cell loads the IMDB dataset (using `tensorflow_datasets`, not `datasets`), initializes a simple classifier, and trains it using Keras."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "2ew7HTbPpCJH",
    "outputId": "1c1711e1-cf82-4b09-899f-db7c9bb68513"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:absl:No config specified, defaulting to first: imdb_reviews/plain_text\n",
      "INFO:absl:Overwrite dataset info from restored data version.\n",
      "INFO:absl:Reusing dataset imdb_reviews (/root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0)\n",
      "INFO:absl:Constructing tf.data.Dataset for split ['train', 'test'], from /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Version:  2.2.0\n",
      "Eager mode:  True\n",
      "Hub version:  0.8.0\n",
      "GPU is NOT AVAILABLE\n",
      "Model: \"sequential_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "keras_layer_1 (KerasLayer)   (None, 20)                400020    \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 16)                336       \n",
      "_________________________________________________________________\n",
      "dense_3 (Dense)              (None, 1)                 17        \n",
      "=================================================================\n",
      "Total params: 400,373\n",
      "Trainable params: 400,373\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Epoch 1/40\n",
      "30/30 [==============================] - 2s 75ms/step - loss: 0.6652 - accuracy: 0.5760 - val_loss: 0.6214 - val_accuracy: 0.6253\n",
      "Epoch 2/40\n",
      "30/30 [==============================] - 2s 72ms/step - loss: 0.5972 - accuracy: 0.6523 - val_loss: 0.5783 - val_accuracy: 0.6646\n",
      "Epoch 3/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.5533 - accuracy: 0.6951 - val_loss: 0.5424 - val_accuracy: 0.7026\n",
      "Epoch 4/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.5126 - accuracy: 0.7319 - val_loss: 0.5082 - val_accuracy: 0.7335\n",
      "Epoch 5/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.4739 - accuracy: 0.7641 - val_loss: 0.4763 - val_accuracy: 0.7590\n",
      "Epoch 6/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.4385 - accuracy: 0.7911 - val_loss: 0.4478 - val_accuracy: 0.7828\n",
      "Epoch 7/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.4038 - accuracy: 0.8133 - val_loss: 0.4227 - val_accuracy: 0.7892\n",
      "Epoch 8/40\n",
      "30/30 [==============================] - 2s 72ms/step - loss: 0.3712 - accuracy: 0.8327 - val_loss: 0.3987 - val_accuracy: 0.8119\n",
      "Epoch 9/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.3416 - accuracy: 0.8504 - val_loss: 0.3784 - val_accuracy: 0.8234\n",
      "Epoch 10/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.3162 - accuracy: 0.8623 - val_loss: 0.3619 - val_accuracy: 0.8410\n",
      "Epoch 11/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.2914 - accuracy: 0.8761 - val_loss: 0.3476 - val_accuracy: 0.8471\n",
      "Epoch 12/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.2705 - accuracy: 0.8869 - val_loss: 0.3367 - val_accuracy: 0.8512\n",
      "Epoch 13/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.2518 - accuracy: 0.8956 - val_loss: 0.3288 - val_accuracy: 0.8495\n",
      "Epoch 14/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.2351 - accuracy: 0.9043 - val_loss: 0.3208 - val_accuracy: 0.8591\n",
      "Epoch 15/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.2193 - accuracy: 0.9133 - val_loss: 0.3156 - val_accuracy: 0.8590\n",
      "Epoch 16/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.2050 - accuracy: 0.9202 - val_loss: 0.3112 - val_accuracy: 0.8651\n",
      "Epoch 17/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.1923 - accuracy: 0.9276 - val_loss: 0.3114 - val_accuracy: 0.8580\n",
      "Epoch 18/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.1814 - accuracy: 0.9303 - val_loss: 0.3069 - val_accuracy: 0.8677\n",
      "Epoch 19/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.1696 - accuracy: 0.9370 - val_loss: 0.3067 - val_accuracy: 0.8663\n",
      "Epoch 20/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.1594 - accuracy: 0.9419 - val_loss: 0.3091 - val_accuracy: 0.8634\n",
      "Epoch 21/40\n",
      "30/30 [==============================] - 2s 74ms/step - loss: 0.1495 - accuracy: 0.9439 - val_loss: 0.3066 - val_accuracy: 0.8748\n",
      "Epoch 22/40\n",
      "30/30 [==============================] - 2s 75ms/step - loss: 0.1403 - accuracy: 0.9502 - val_loss: 0.3075 - val_accuracy: 0.8706\n",
      "Epoch 23/40\n",
      "30/30 [==============================] - 2s 73ms/step - loss: 0.1323 - accuracy: 0.9539 - val_loss: 0.3114 - val_accuracy: 0.8680\n",
      "Epoch 24/40\n",
      "30/30 [==============================] - 2s 73ms/step - loss: 0.1232 - accuracy: 0.9578 - val_loss: 0.3126 - val_accuracy: 0.8716\n",
      "Epoch 25/40\n",
      "30/30 [==============================] - 2s 72ms/step - loss: 0.1157 - accuracy: 0.9604 - val_loss: 0.3158 - val_accuracy: 0.8710\n",
      "Epoch 26/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.1090 - accuracy: 0.9630 - val_loss: 0.3181 - val_accuracy: 0.8725\n",
      "Epoch 27/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.1017 - accuracy: 0.9665 - val_loss: 0.3234 - val_accuracy: 0.8697\n",
      "Epoch 28/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.0954 - accuracy: 0.9697 - val_loss: 0.3291 - val_accuracy: 0.8686\n",
      "Epoch 29/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.0894 - accuracy: 0.9720 - val_loss: 0.3305 - val_accuracy: 0.8717\n",
      "Epoch 30/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.0833 - accuracy: 0.9753 - val_loss: 0.3362 - val_accuracy: 0.8723\n",
      "Epoch 31/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.0776 - accuracy: 0.9771 - val_loss: 0.3422 - val_accuracy: 0.8721\n",
      "Epoch 32/40\n",
      "30/30 [==============================] - 2s 71ms/step - loss: 0.0726 - accuracy: 0.9798 - val_loss: 0.3484 - val_accuracy: 0.8744\n",
      "Epoch 33/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.0678 - accuracy: 0.9825 - val_loss: 0.3538 - val_accuracy: 0.8722\n",
      "Epoch 34/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.0631 - accuracy: 0.9837 - val_loss: 0.3616 - val_accuracy: 0.8736\n",
      "Epoch 35/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.0586 - accuracy: 0.9861 - val_loss: 0.3680 - val_accuracy: 0.8724\n",
      "Epoch 36/40\n",
      "30/30 [==============================] - 2s 69ms/step - loss: 0.0550 - accuracy: 0.9875 - val_loss: 0.3772 - val_accuracy: 0.8742\n",
      "Epoch 37/40\n",
      "30/30 [==============================] - 2s 69ms/step - loss: 0.0506 - accuracy: 0.9887 - val_loss: 0.3821 - val_accuracy: 0.8709\n",
      "Epoch 38/40\n",
      "30/30 [==============================] - 2s 70ms/step - loss: 0.0471 - accuracy: 0.9901 - val_loss: 0.3907 - val_accuracy: 0.8692\n",
      "Epoch 39/40\n",
      "30/30 [==============================] - 2s 68ms/step - loss: 0.0436 - accuracy: 0.9914 - val_loss: 0.3980 - val_accuracy: 0.8703\n",
      "Epoch 40/40\n",
      "30/30 [==============================] - 2s 69ms/step - loss: 0.0405 - accuracy: 0.9922 - val_loss: 0.4070 - val_accuracy: 0.8699\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_hub as hub\n",
    "import tensorflow_datasets as tfds\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "print(\"Version: \", tf.__version__)\n",
    "print(\"Eager mode: \", tf.executing_eagerly())\n",
    "print(\"Hub version: \", hub.__version__)\n",
    "print(\"GPU is\", \"available\" if tf.config.list_physical_devices('GPU') else \"NOT AVAILABLE\")\n",
    "\n",
    "train_data, test_data = tfds.load(name=\"imdb_reviews\", split=[\"train\", \"test\"], \n",
    "                                  batch_size=-1, as_supervised=True)\n",
    "\n",
    "train_examples, train_labels = tfds.as_numpy(train_data)\n",
    "test_examples, test_labels = tfds.as_numpy(test_data)\n",
    "\n",
    "model = \"https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1\"\n",
    "hub_layer = hub.KerasLayer(model, output_shape=[20], input_shape=[], \n",
    "                           dtype=tf.string, trainable=True)\n",
    "hub_layer(train_examples[:3])\n",
    "\n",
    "model = tf.keras.Sequential()\n",
    "model.add(hub_layer)\n",
    "model.add(tf.keras.layers.Dense(16, activation='relu'))\n",
    "model.add(tf.keras.layers.Dense(1))\n",
    "\n",
    "model.summary()\n",
    "\n",
    "x_val = train_examples[:10000]\n",
    "partial_x_train = train_examples[10000:]\n",
    "\n",
    "y_val = train_labels[:10000]\n",
    "partial_y_train = train_labels[10000:]\n",
    "\n",
    "model.compile(optimizer='adam',\n",
    "              loss=tf.losses.BinaryCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "history = model.fit(partial_x_train,\n",
    "                    partial_y_train,\n",
    "                    epochs=40,\n",
    "                    batch_size=512,\n",
    "                    validation_data=(x_val, y_val),\n",
    "                    verbose=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "3varlQvrnHqV"
   },
   "source": [
    "## Attacking\n",
    "\n",
    "For each input, our classifier outputs a single number (a logit) that indicates how positive or negative the model finds the input. For binary classification, TextAttack expects two numbers for each input: one score per class, negative and positive. We have to post-process each output to fit this TextAttack format. To add this post-processing, we need to implement a custom model wrapper class (instead of using the built-in `textattack.models.wrappers.TensorFlowModelWrapper`).\n",
    "\n",
    "Each `ModelWrapper` must implement a single method, `__call__`, which takes a list of strings and returns a `List`, `np.ndarray`, or `torch.Tensor` of predictions."
   ]
  },
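  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the post-processing described above, the sketch below converts single-logit outputs into two-column class scores using plain NumPy. The helper name `logits_to_two_class_probs` is hypothetical and not part of TextAttack, and the sketch assumes the model returns one logit per input with shape `(batch_size, 1)`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def logits_to_two_class_probs(logits):\n",
    "    # Hypothetical helper: maps logits of shape (batch_size, 1) to\n",
    "    # probabilities of shape (batch_size, 2), ordered [negative, positive].\n",
    "    logits = np.asarray(logits, dtype=np.float64).reshape(-1)\n",
    "    positive = 1.0 / (1.0 + np.exp(-logits))  # sigmoid\n",
    "    return np.stack((1.0 - positive, positive), axis=1)\n",
    "\n",
    "example = logits_to_two_class_probs([[0.0], [2.0]])\n",
    "print(example.shape)\n",
    "```\n",
    "\n",
    "A logit of 0 maps to a score of 0.5 for both classes, and every row sums to 1."
   ]
  },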
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "fHX3Lo7wU2LM"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from textattack.models.wrappers import ModelWrapper\n",
    "\n",
    "class CustomTensorFlowModelWrapper(ModelWrapper):\n",
    "    def __init__(self, model):\n",
    "        self.model = model\n",
    "\n",
    "    def __call__(self, text_input_list):\n",
    "        text_array = np.array(text_input_list)\n",
    "        preds = self.model(text_array).numpy()\n",
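    "        # The model outputs one logit per input; a sigmoid turns it into\n",
    "        # the probability of the positive class (a value between 0 and 1).\n",
    "        positive_probs = torch.sigmoid(torch.tensor(preds)).squeeze(dim=-1)\n",
    "        # TextAttack expects one score per class, so we stack the negative and\n",
    "        # positive probabilities into a tensor of shape (batch_size, 2).\n",
    "        final_preds = torch.stack((1 - positive_probs, positive_probs), dim=1)\n",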
    "        return final_preds\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Ku71HuZ4n7ih"
   },
   "source": [
    "Let's test our model wrapper out to make sure it can use our model to return predictions in the correct format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 52
    },
    "colab_type": "code",
    "id": "9hgiLQC4ejmM",
    "outputId": "132c3be5-fe5e-4be4-ef98-5c2efedc0dfd"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.2745, 0.7255],\n",
       "        [0.0072, 0.9928]])"
      ]
     },
     "execution_count": 14,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "CustomTensorFlowModelWrapper(model)(['I hate you so much', 'I love you'])"
   ]
  },
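  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sanity check along these lines can also be automated. The sketch below uses a hypothetical helper, `check_two_class_output`, that verifies an array of scores has one row per input, two columns, and rows that sum to 1. It is run on the values printed above rather than on the live model, so it assumes nothing beyond NumPy; a tensor returned by the wrapper could be converted with `.numpy()` first.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def check_two_class_output(scores, num_inputs):\n",
    "    # Hypothetical helper: returns True when scores look like\n",
    "    # per-class probabilities for a binary classifier.\n",
    "    scores = np.asarray(scores, dtype=np.float64)\n",
    "    has_shape = scores.shape == (num_inputs, 2)\n",
    "    in_range = bool(np.all((scores >= 0.0) & (scores <= 1.0)))\n",
    "    # Only sum over columns once the shape is known to be two-dimensional.\n",
    "    sums_to_one = has_shape and np.allclose(scores.sum(axis=1), 1.0, atol=1e-3)\n",
    "    return bool(has_shape and in_range and sums_to_one)\n",
    "\n",
    "printed_scores = [[0.2745, 0.7255], [0.0072, 0.9928]]\n",
    "print(check_two_class_output(printed_scores, num_inputs=2))\n",
    "```"
   ]
  },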
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-Bs14Hr4n_Sp"
   },
   "source": [
    "Looks good! Now we can initialize our model wrapper with the model we trained and pass it to an attack recipe, which builds an instance of `textattack.attack.Attack`.\n",
    "\n",
    "We'll use the `PWWSRen2019` recipe as our attack, and attack 10 samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 780
    },
    "colab_type": "code",
    "id": "07mOE-wLVQDR",
    "outputId": "e47a099e-c0f6-4c21-8e52-1a437741bc16"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:nlp.builder:Using custom data configuration default\n",
      "\u001b[34;1mtextattack\u001b[0m: Loading \u001b[94mnlp\u001b[0m dataset \u001b[94mrotten_tomatoes\u001b[0m, split \u001b[94mtest\u001b[0m.\n",
      "\u001b[34;1mtextattack\u001b[0m: Unknown if model of class <class '__main__.CustomTensorFlowModelWrapper'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92mPositive (60%)\u001b[0m --> \u001b[37m[SKIPPED]\u001b[0m\n",
      "\n",
      "kaufman's script is never especially clever and often is rather pretentious .\n",
      "\u001b[91mNegative (98%)\u001b[0m --> \u001b[92mPositive (59%)\u001b[0m\n",
      "\n",
      "an \u001b[91munfortunate\u001b[0m title for a film that has \u001b[91mnothing\u001b[0m endearing about it .\n",
      "\n",
      "an \u001b[92minauspicious\u001b[0m title for a film that has \u001b[92mzip\u001b[0m endearing about it .\n",
      "\u001b[91mNegative (73%)\u001b[0m --> \u001b[92mPositive (59%)\u001b[0m\n",
      "\n",
      "sade achieves the near-impossible : it \u001b[91mturns\u001b[0m the marquis de sade into a dullard .\n",
      "\n",
      "sade achieves the near-impossible : it \u001b[92mtour\u001b[0m the marquis de sade into a dullard .\n",
      "\u001b[91mNegative (98%)\u001b[0m --> \u001b[37m[SKIPPED]\u001b[0m\n",
      "\n",
      ". . . planos fijos , tomas largas , un ritmo pausado y una sutil observación de sus personajes , sin estridencias ni grandes revelaciones .\n",
      "\u001b[91mNegative (97%)\u001b[0m --> \u001b[92mPositive (62%)\u001b[0m\n",
      "\n",
      "charly comes off as emotionally manipulative and \u001b[91msadly\u001b[0m imitative of innumerable past love story derisions .\n",
      "\n",
      "charly comes off as emotionally manipulative and \u001b[92mdeplorably\u001b[0m imitative of innumerable past love story derisions .\n",
      "\u001b[91mNegative (70%)\u001b[0m --> \u001b[92mPositive (93%)\u001b[0m\n",
      "\n",
      "any intellectual \u001b[91marguments\u001b[0m being made about the nature of god are framed in a drama so clumsy , there is a real danger less sophisticated audiences will mistake it for an endorsement of the very things that bean abhors .\n",
      "\n",
      "any intellectual \u001b[92mcontention\u001b[0m being made about the nature of god are framed in a drama so clumsy , there is a real danger less sophisticated audiences will mistake it for an endorsement of the very things that bean abhors .\n",
      "\u001b[92mPositive (97%)\u001b[0m --> \u001b[37m[SKIPPED]\u001b[0m\n",
      "\n",
      "a handsome but unfulfilling suspense drama more suited to a quiet evening on pbs than a night out at an amc .\n",
      "\u001b[91mNegative (93%)\u001b[0m --> \u001b[37m[SKIPPED]\u001b[0m\n",
      "\n",
      "you will likely prefer to keep on watching .\n",
      "\u001b[91mNegative (100%)\u001b[0m --> \u001b[92mPositive (74%)\u001b[0m\n",
      "\n",
      "what ensues are \u001b[91mmuch\u001b[0m blood-splattering , \u001b[91mmass\u001b[0m drug-induced \u001b[91mbowel\u001b[0m evacuations , and none-too-funny commentary on the cultural \u001b[91mdistinctions\u001b[0m between americans and \u001b[91mbrits\u001b[0m .\n",
      "\n",
      "what ensues are \u001b[92mlots\u001b[0m blood-splattering , \u001b[92mplenty\u001b[0m drug-induced \u001b[92mintestine\u001b[0m evacuations , and none-too-funny commentary on the cultural \u001b[92mdistinction\u001b[0m between americans and \u001b[92mBrits\u001b[0m .\n",
      "\u001b[92mPositive (100%)\u001b[0m --> \u001b[37m[SKIPPED]\u001b[0m\n",
      "\n",
      "a film without surprise geared toward maximum comfort and familiarity .\n"
     ]
    }
   ],
   "source": [
    "model_wrapper = CustomTensorFlowModelWrapper(model)\n",
    "\n",
    "from textattack.datasets import HuggingFaceDataset\n",
    "from textattack.attack_recipes import PWWSRen2019\n",
    "\n",
    "dataset = HuggingFaceDataset(\"rotten_tomatoes\", None, \"test\", shuffle=True)\n",
    "attack = PWWSRen2019.build(model_wrapper)\n",
    "\n",
    "results_iterable = attack.attack_dataset(dataset, indices=range(10))\n",
    "for result in results_iterable:\n",
    "    print(result.__str__(color_method='ansi'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "P3L9ccqGoS-J"
   },
   "source": [
    "## Conclusion\n",
    "\n",
    "We successfully trained a model, wrapped it in a TextAttack `ModelWrapper`, and used that wrapper in an attack. The same approach works for adapting any model, whether built with TensorFlow or another library, for use with TextAttack."
   ]
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "collapsed_sections": [],
   "name": "[TextAttack] tensorflow/keras example",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
